热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

词表|多头_猿创征文|深度学习前沿应用文本生成

篇首语:本文由编程笔记#小编为大家整理,主要介绍了猿创征文|深度学习前沿应用文本生成相关的知识,希望对你有一定的参考价值。 猿创征文|【深度学习前沿应用】文本生成 作者简介&#

篇首语:本文由编程笔记#小编为大家整理,主要介绍了猿创征文|深度学习前沿应用文本生成相关的知识,希望对你有一定的参考价值。



猿创征文|【深度学习前沿应用】文本生成





作者简介:在校大学生一枚,C/C++领域新星创作者,华为云享专家,阿里云专家博主,腾云先锋(TDP)成员,云曦智划项目总负责人,全国高等学校计算机教学与产业实践资源建设专家委员会(TIPCC)志愿者,以及编程爱好者,期待和大家一起学习,一起进步~
.
博客主页:ぃ灵彧が的学习日志
.
本文专栏:人工智能
.
专栏寄语:若你决定灿烂,山无遮,海无拦
.



文章目录


  • 猿创征文|【深度学习前沿应用】文本生成
  • 前言
      • 什么是文本生成?


  • 一、数据加载及预处理
    • (一)、数据加载
    • (二)、构建词表
    • (三)、创建指定数据格式

  • 二、模型配置
    • (一)、定义网络超参数
    • (二)、定义编码器
    • (三)、定义解码器

  • 三、模型训练
  • 四、模型预测
  • 总结




前言

什么是文本生成?

在自然语言处理领域,文本生成任务是指根据给定的输入,自动生成对应的输出,典型的任务包含:机器翻译、智能问答等。文本生成任务在注意力机制提出之后取得了显著的效果,尤其是在2018年基于多头注意力机制的Transformer(原理如下图1所示)在机器翻译领域取得当时最优效果时,基于Transformer的文本生成任务也进入了新的繁荣时期。

本实验的目的是演示如何使用经典的Transformer实现英-中机器翻译,实验平台为百度AI Studio,实验环境为Python3.7,Paddle2.0。




一、数据加载及预处理


(一)、数据加载

本实验选用开源的小型英-中翻译CMN数据集,该数据集中包含样本总数24360条,均为短文本,部分数据展示如下图2所示:

不同于图像处理,在处理自然语言时,需要指定文本的长度,便于进行批量计算,因此,在数据预处理阶段,应该先统计数据集中文本的长度,然后指定一个恰当的值,进行统一处理。




  1. 导入相关包

import paddle
import paddle.nn.functional as F
import re
import numpy as np
print(paddle.__version__)
# cpu/gpu环境选择,在 paddle.set_device() 输入对应运行设备。
# device = paddle.set_device('gpu')



  1. 统计数据集中句子的长度等信息

# 统计数据集中句子的长度等信息
lines = open('data/data78721/cmn.txt','r',encoding='utf-8').readlines()
print(len(lines))
datas = []
dic_en =
dic_cn =
for line in lines:
ll = line.strip().split('\\t')
if len(ll)<2:
continue
datas.append([ll[0].lower().split(&#39; &#39;)[1:-1],list(ll[1])])
# print(ll[0])
if len(ll[0].split(&#39; &#39;)) not in dic_en:
dic_en[len(ll[0].split(&#39; &#39;))] &#61; 1
else:
dic_en[len(ll[0].split(&#39; &#39;))] &#43;&#61;1
if len(ll[1]) not in dic_cn:
dic_cn[len(ll[1])] &#61; 1
else:
dic_cn[len(ll[1])] &#43;&#61;1
keys_en &#61; list(dic_en.keys())
keys_en.sort()
count &#61; 0
# print(&#39;英文长度统计&#xff1a;&#39;)
for k in keys_en:
count &#43;&#61; dic_en[k]
# print(k,dic_en[k],count/len(lines))
keys_cn &#61; list(dic_cn.keys())
keys_cn.sort()
count &#61; 0
# print(&#39;中文长度统计&#xff1a;&#39;)
for k in keys_cn:
count &#43;&#61; dic_cn[k]
# print(k,dic_cn[k],count/len(lines))

en_length &#61; 10
cn_length &#61; 10



(二)、构建词表

对于中英文&#xff0c;需要分别构建词表&#xff0c;进行词向量学习&#xff0c;除此之外&#xff0c;还需要在每个词表中加入开始符号、结束符合以及填充符号&#xff1a;

# 构建中英文词表
en_vocab &#61;
cn_vocab &#61;
en_vocab[&#39;&#39;], en_vocab[&#39;&#39;], en_vocab[&#39;&#39;] &#61; 0, 1, 2
cn_vocab[&#39;&#39;], cn_vocab[&#39;&#39;], cn_vocab[&#39;&#39;] &#61; 0, 1, 2
en_idx, cn_idx &#61; 3, 3
for en, cn in datas:
# print(en,cn)
for w in en:
if w not in en_vocab:
en_vocab[w] &#61; en_idx
en_idx &#43;&#61; 1
for w in cn:
if w not in cn_vocab:
cn_vocab[w] &#61; cn_idx
cn_idx &#43;&#61; 1
print(len(list(en_vocab)))
print(len(list(cn_vocab)))
&#39;&#39;&#39;
英文词表长度&#xff1a;6057
中文词表长度&#xff1a;3533
&#39;&#39;&#39;




(三)、创建指定数据格式

需要将输入英文与输出中文封装为指定格式&#xff0c;即为编码器端输入添加结束符号并填充至固定长度&#xff0c;为解码器输入添加开始、结束符号并填充至固定长度&#xff0c;解码器端输出的正确答案应该只添加结束符号并且填充至固定长度。

padded_en_sents &#61; []
padded_cn_sents &#61; []
padded_cn_label_sents &#61; []
for en, cn in datas:
if len(en)>en_length:
en &#61; en[:en_length]
if len(cn)>cn_length:
cn &#61; cn[:cn_length]
padded_en_sent &#61; en &#43; [&#39;&#39;] &#43; [&#39;&#39;] * (en_length - len(en))
padded_en_sent.reverse()
padded_cn_sent &#61; [&#39;&#39;] &#43; cn &#43; [&#39;&#39;] &#43; [&#39;&#39;] * (cn_length - len(cn))
padded_cn_label_sent &#61; cn &#43; [&#39;&#39;] &#43; [&#39;&#39;] * (cn_length - len(cn) &#43; 1)

padded_en_sents.append(np.array([en_vocab[w] for w in padded_en_sent]))
padded_cn_sents.append(np.array([cn_vocab[w] for w in padded_cn_sent]) )
padded_cn_label_sents.append(np.array([cn_vocab[w] for w in padded_cn_label_sent]))
train_en_sents &#61; np.array(padded_en_sents)
train_cn_sents &#61; np.array(padded_cn_sents)
train_cn_label_sents &#61; np.array(padded_cn_label_sents)

print(train_en_sents.shape)
print(train_cn_sents.shape)
print(train_cn_label_sents.shape)



二、模型配置


(一)、定义网络超参数

embedding_size &#61; 128
hidden_size &#61; 512
num_encoder_lstm_layers &#61; 1
en_vocab_size &#61; len(list(en_vocab))
cn_vocab_size &#61; len(list(cn_vocab))
epochs &#61; 20
batch_size &#61; 16



(二)、定义编码器

# encoder: simply learn representation of source sentence
class Encoder(paddle.nn.Layer):
def __init__(self,en_vocab_size, embedding_size,num_layers&#61;2,head_number&#61;2,middle_units&#61;512):
super(Encoder, self).__init__()
self.emb &#61; paddle.nn.Embedding(en_vocab_size, embedding_size,)
"""
d_model (int) - 输入输出的维度。
nhead (int) - 多头注意力机制的Head数量。
dim_feedforward (int) - 前馈神经网络中隐藏层的大小。
"""

encoder_layer &#61; paddle.nn.TransformerEncoderLayer(embedding_size, head_number, middle_units)
self.encoder &#61; paddle.nn.TransformerEncoder(encoder_layer, num_layers)
def forward(self, x):
x &#61; self.emb(x)
en_out &#61; self.encoder(x)
return en_out



(三)、定义解码器

class Decoder(paddle.nn.Layer):
def __init__(self,cn_vocab_size, embedding_size,num_layers&#61;2,head_number&#61;2,middle_units&#61;512):
super(Decoder, self).__init__()
self.emb &#61; paddle.nn.Embedding(cn_vocab_size, embedding_size)

decoder_layer &#61; paddle.nn.TransformerDecoderLayer(embedding_size, head_number, middle_units)
self.decoder &#61; paddle.nn.TransformerDecoder(decoder_layer, num_layers)

# for computing output logits
self.outlinear &#61;paddle.nn.Linear(embedding_size, cn_vocab_size)
def forward(self, x, encoder_outputs):
x &#61; self.emb(x)
# dec_input, enc_output,self_attn_mask, cross_attn_mask
de_out &#61; self.decoder(x, encoder_outputs)
output &#61; self.outlinear(de_out)
output &#61; paddle.squeeze(output)
return output

三、模型训练

encoder &#61; Encoder(en_vocab_size, embedding_size)
decoder &#61; Decoder(cn_vocab_size, embedding_size)
opt &#61; paddle.optimizer.Adam(learning_rate&#61;0.0001,
parameters&#61;encoder.parameters() &#43; decoder.parameters())
for epoch in range(epochs):
print("epoch:".format(epoch))
# shuffle training data
perm &#61; np.random.permutation(len(train_en_sents))
train_en_sents_shuffled &#61; train_en_sents[perm]
train_cn_sents_shuffled &#61; train_cn_sents[perm]
train_cn_label_sents_shuffled &#61; train_cn_label_sents[perm]
# print(train_en_sents_shuffled.shape[0],train_en_sents_shuffled.shape[1])
for iteration in range(train_en_sents_shuffled.shape[0] // batch_size):
x_data &#61; train_en_sents_shuffled[(batch_size*iteration):(batch_size*(iteration&#43;1))]
sent &#61; paddle.to_tensor(x_data)
en_repr &#61; encoder(sent)
x_cn_data &#61; train_cn_sents_shuffled[(batch_size*iteration):(batch_size*(iteration&#43;1))]
x_cn_label_data &#61; train_cn_label_sents_shuffled[(batch_size*iteration):(batch_size*(iteration&#43;1))]

loss &#61; paddle.zeros([1])
for i in range( cn_length &#43; 2):
cn_word &#61; paddle.to_tensor(x_cn_data[:,i:i&#43;1])
cn_word_label &#61; paddle.to_tensor(x_cn_label_data[:,i])
logits &#61; decoder(cn_word, en_repr)
step_loss &#61; F.cross_entropy(logits, cn_word_label)
loss &#43;&#61; step_loss
loss &#61; loss / (cn_length &#43; 2)
if(iteration % 50 &#61;&#61; 0):
print("iter , loss:".format(iteration, loss.numpy()))
loss.backward()
opt.step()
opt.clear_grad()

输出结果如下图3所示&#xff1a;




四、模型预测

encoder.eval()
decoder.eval()
num_of_exampels_to_evaluate &#61; 10
indices &#61; np.random.choice(len(train_en_sents), num_of_exampels_to_evaluate, replace&#61;False)
x_data &#61; train_en_sents[indices]
sent &#61; paddle.to_tensor(x_data)
en_repr &#61; encoder(sent)
word &#61; np.array(
[[cn_vocab[&#39;&#39;]]] * num_of_exampels_to_evaluate
)
word &#61; paddle.to_tensor(word)

decoded_sent &#61; []
for i in range(cn_length &#43; 2):
logits &#61; decoder(word, en_repr)
word &#61; paddle.argmax(logits, axis&#61;1)
decoded_sent.append(word.numpy())
word &#61; paddle.unsqueeze(word, axis&#61;-1)
results &#61; np.stack(decoded_sent, axis&#61;1)
for i in range(num_of_exampels_to_evaluate):
print(&#39;---------------------&#39;)
en_input &#61; " ".join(datas[indices[i]][0])
ground_truth_translate &#61; "".join(datas[indices[i]][1])
model_translate &#61; ""
for k in results[i]:
w &#61; list(cn_vocab)[k]
if w !&#61; &#39;&#39; and w !&#61; &#39;&#39;:
model_translate &#43;&#61; w
print(en_input)
print("true: ".format(ground_truth_translate))
print("pred: ".format(model_translate))

输出结果如下图4所示&#xff1a;




总结

本系列文章内容为根据清华社出版的《机器学习实践》所作的相关笔记和感悟&#xff0c;其中代码均为基于百度飞桨开发&#xff0c;若有任何侵权和不妥之处&#xff0c;请私信于我&#xff0c;定积极配合处理&#xff0c;看到必回&#xff01;&#xff01;&#xff01;

最后&#xff0c;引用本次活动的一句话&#xff0c;来作为文章的结语&#xff5e;(&#xffe3;▽&#xffe3;&#xff5e;)~&#xff1a;

学习的最大理由是想摆脱平庸&#xff0c;早一天就多一份人生的精彩&#xff1b;迟一天就多一天平庸的困扰。

ps&#xff1a;更多精彩内容还请进入本文专栏&#xff1a;人工智能&#xff0c;进行查看&#xff0c;欢迎大家支持与指教啊&#xff5e;(&#xffe3;▽&#xffe3;&#xff5e;)~


推荐阅读
  • PHP中元素的计量单位是什么? ... [详细]
  • Python与R语言在功能和应用场景上各有优势。尽管R语言在统计分析和数据可视化方面具有更强的专业性,但Python作为一种通用编程语言,适用于更广泛的领域,包括Web开发、自动化脚本和机器学习等。对于初学者而言,Python的学习曲线更为平缓,上手更加容易。此外,Python拥有庞大的社区支持和丰富的第三方库,使其在实际应用中更具灵活性和扩展性。 ... [详细]
  • Understanding the Distinction Between decodeURIComponent and Its Encoding Counterpart
    本文探讨了 JavaScript 中 `decodeURIComponent` 和其编码对应函数之间的区别。通过详细分析这两个函数的功能和应用场景,帮助开发者更好地理解和使用它们,避免常见的编码和解码错误。 ... [详细]
  • Java服务问题快速定位与解决策略全面指南 ... [详细]
  • BZOJ4240 Gym 102082G:贪心算法与树状数组的综合应用
    BZOJ4240 Gym 102082G 题目 "有趣的家庭菜园" 结合了贪心算法和树状数组的应用,旨在解决在有限时间和内存限制下高效处理复杂数据结构的问题。通过巧妙地运用贪心策略和树状数组,该题目能够在 10 秒的时间限制和 256MB 的内存限制内,有效处理大量输入数据,实现高性能的解决方案。提交次数为 756 次,成功解决次数为 349 次,体现了该题目的挑战性和实际应用价值。 ... [详细]
  • 在CentOS上部署和配置FreeSWITCH
    在CentOS系统上部署和配置FreeSWITCH的过程涉及多个步骤。本文详细介绍了从源代码安装FreeSWITCH的方法,包括必要的依赖项安装、编译和配置过程。此外,还提供了常见的配置选项和故障排除技巧,帮助用户顺利完成部署并确保系统的稳定运行。 ... [详细]
  • 本文作为“实现简易版Spring系列”的第五篇,继前文深入探讨了Spring框架的核心技术之一——控制反转(IoC)之后,将重点转向另一个关键技术——面向切面编程(AOP)。对于使用Spring框架进行开发的开发者来说,AOP是一个不可或缺的概念。了解AOP的背景及其基本原理,对于掌握这一技术至关重要。本文将通过具体示例,详细解析AOP的实现机制,帮助读者更好地理解和应用这一技术。 ... [详细]
  • HBase在金融大数据迁移中的应用与挑战
    随着最后一台设备的下线,标志着超过10PB的HBase数据迁移项目顺利完成。目前,新的集群已在新机房稳定运行超过两个月,监控数据显示,新集群的查询响应时间显著降低,系统稳定性大幅提升。此外,数据消费的波动也变得更加平滑,整体性能得到了显著优化。 ... [详细]
  • 本文详细解析了 MySQL 5.7.20 版本中二进制日志(binlog)崩溃恢复机制的工作流程。假设使用 InnoDB 存储引擎,并且启用了 `sync_binlog=1` 配置,文章深入探讨了在系统崩溃后如何通过 binlog 进行数据恢复,确保数据的一致性和完整性。 ... [详细]
  • MySQL性能优化与调参指南【数据库管理】
    本文详细探讨了MySQL数据库的性能优化与参数调整技巧,旨在帮助数据库管理员和开发人员提升系统的运行效率。内容涵盖索引优化、查询优化、配置参数调整等方面,结合实际案例进行深入分析,提供实用的操作建议。此外,还介绍了常见的性能监控工具和方法,助力读者全面掌握MySQL性能优化的核心技能。 ... [详细]
  • 利用 JavaScript 实现定时任务的高效执行方法(代码可直接复用) ... [详细]
  • 在处理大规模并发请求时,传统的多线程或多进程模型往往无法有效解决性能瓶颈问题。尽管它们在处理小规模任务时能提升效率,但在高并发场景下,系统资源的过度消耗和上下文切换的开销会显著降低整体性能。相比之下,Python 的 `asyncio` 模块通过协程提供了一种轻量级且高效的并发解决方案。本文将深入解析 `asyncio` 模块的原理及其在实际应用中的优化技巧,帮助开发者更好地利用协程技术提升程序性能。 ... [详细]
  • 2019年后蚂蚁集团与拼多多面试经验详述与深度剖析
    2019年后蚂蚁集团与拼多多面试经验详述与深度剖析 ... [详细]
  • 本文详细介绍了HDFS的基础知识及其数据读写机制。首先,文章阐述了HDFS的架构,包括其核心组件及其角色和功能。特别地,对NameNode进行了深入解析,指出其主要负责在内存中存储元数据、目录结构以及文件块的映射关系,并通过持久化方案确保数据的可靠性和高可用性。此外,还探讨了DataNode的角色及其在数据存储和读取过程中的关键作用。 ... [详细]
  • NVIDIA最新推出的Ampere架构标志着显卡技术的一次重大突破,不仅在性能上实现了显著提升,还在能效比方面进行了深度优化。该架构融合了创新设计与技术改进,为用户带来更加流畅的图形处理体验,同时降低了功耗,提升了计算效率。 ... [详细]
author-avatar
ecrbw_9870105634
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有